
Add warm-up functionality with tensor to trajectory helper functions #224

Open
alexandrelarouche wants to merge 3 commits into master

Conversation

alexandrelarouche (Contributor) commented Dec 11, 2024

I'm hoping to get your insight on how to make this better. Some parts of the code are sketchy and have been highlighted with a WARNING tag in the comments.

Add function to generate trajectories from states and actions tensors
Add function to crudely warmup a GFN (early stopping or other tricks not included)
…ging since every other GFN loss method returns tensor
saleml (Collaborator) commented Jan 11, 2025

Thank you for the PR. Could you please elaborate a little bit more on it? What use-case are you targeting? Where do you use the new functions? Is there a way to test them and see their effects in the repo?

Thanks

alexandrelarouche (Contributor Author) commented Jan 12, 2025

Hi Salem!

Yes, sorry, I contacted Joseph via Slack prior to the PR, but I should've given more detail here.

These functions are provided as a means to generate warm-up trajectories from external state-action tensors (e.g. expert knowledge, or another algorithm's output). My rationale for PR'ing these simple functions is that I found the whole process to be non-trivial when working from the sources/docs (namely, watch for the WARNING tags), and I thought other users could benefit from having either a full implementation or an example.

def states_actions_tns_to_traj(
    states_tns: torch.Tensor,
    actions_tns: torch.Tensor,
    env: DiscreteEnv,
) -> Trajectories:

is a utility function that maps state tensors and actions to a Trajectories object. Effectively, this is a translation function for a prior that comes from outside the torch-gfn ecosystem and would not already be wrapped in a Trajectories.

def warm_up(
    replay_buf: ReplayBuffer,
    optimizer: torch.optim.Optimizer,
    gfn: GFlowNet,
    env: Env,
    n_steps: int,
    batch_size: int,
    recalculate_all_logprobs: bool = True,
):

is a training loop over a fixed replay buffer, but it does not assume that any log-probs were computed in the Trajectories generated by the prior. Anyone could implement their own version; I simply provided mine as a crude example (there is no early stopping here, or any training trick). The important/tricky bit, if I remember correctly, lies in setting recalculate_all_logprobs=True for TB-GFNs, because the states_actions_tns_to_traj function creates some dummy log-prob tensors (since the prior is not expected to provide those).

I can write some unit tests for states_actions_tns_to_traj, as I think it is the trickier function of the two. I can also add the docstrings (which I thought I had provided, my bad).

If you have any other feedback, send it my way so that we can implement it and follow your philosophy more closely.

Edit: I clarified why the warm-up function was important to this PR

saleml (Collaborator) commented Jan 22, 2025

Thank you for the PR. The states_actions_tns_to_traj function needs better input validation and documentation. Here's how I would modify it:

    if states_tns.shape[1:] != env.state_shape:
        raise ValueError(
            f"states_tns state dimensions must match env.state_shape {env.state_shape}, "
            f"got shape {states_tns.shape[1:]}"
        )
    if len(actions_tns.shape) != 1:
        raise ValueError(f"actions_tns must be 1D, got batch_shape {actions_tns.shape}")
    if states_tns.shape[0] != actions_tns.shape[0]:
        raise ValueError(
            f"states and actions must have same trajectory length, got "
            f"states: {states_tns.shape[0]}, actions: {actions_tns.shape[0]}"
        )

    # ... rest of the code ...

Possible docstring to add:

   
    """Convert state and action tensors for a single trajectory into a Trajectories object.

    This utility function helps integrate external data (e.g. expert demonstrations)
    into the GFlowNet framework by converting raw tensors into proper Trajectories objects.

    Args:
        states_tns: Tensor of shape [traj_len, *state_shape] containing states for a single trajectory.
        actions_tns: Tensor of shape [traj_len] containing discrete action indices.
        env: The discrete environment that defines the state/action spaces.

    Returns:
        Trajectories: A Trajectories object containing the converted states and actions.

    Raises:
        ValueError: If tensor shapes are invalid or inconsistent.
    """

For warm_up, a docstring would be appreciated. I am not sure why gfn.loss admits an extra argument for the TB loss. I will investigate it.

for epoch in t:
    training_trajs = replay_buf.sample(batch_size)
    optimizer.zero_grad()
    if isinstance(gfn, TBGFlowNet):
Collaborator:

With #231, this could be changed to a cleaner test (checking whether it's a PFBasedGFlowNet).

Contributor Author:

Yes! Seeing your commit, I think this would be cleaner.

Add docstrings
Add input validation (as proposed by saleml)
Add PFBasedGFlowNet verification instead of only TBGFNs (needs merge of GFNOrg#231)
josephdviviano self-assigned this Jan 24, 2025
josephdviviano (Collaborator) left a comment:

Hey, first I want to apologize for taking so long to review this. I hit a bit of a lull over December / early January and have been playing catch-up.

This is a really nice PR, and a feature I'd be excited to use myself in some of the applications I've been looking at. My only request concerns the use of the dummy log_probs: if our library is working properly, it should function as intended with log_probs=None, and if not, we should fix the misbehaving downstream elements, because this is the intended use of the Trajectories container.

Awesome contribution, thank you very much!

# WARNING: This is sketchy. Create dummy values to avoid indexing / batch shape errors.
# WARNING: Assumes gfn.loss() uses recalculate_all_logprobs=True (thus only PFBasedGFlowNet are supported right now)!!
# WARNING: To reviewers: Can we bypass needing to define this?
log_probs = torch.full(size=(len(actions), 1), fill_value=0, dtype=torch.float)
Collaborator:
log_probs can be None, which will trigger a recalculation downstream (or, if it doesn't, we should fix that).

Collaborator:
In other words, you can remove lines 129-132.

Contributor Author:
Ok! I will do some testing this week to see if it behaves as I expect before removing.

    actions,
    log_rewards=log_rewards,
    when_is_done=when_is_done,
    log_probs=log_probs,
Collaborator:
log_probs=None

actions = [
    env.actions_from_tensor(a.unsqueeze(0).unsqueeze(0)) for a in actions_tns
]

# stack is a class method, so actions[0] is just used to access a class instance and is not otherwise relevant
Collaborator:
Really appreciate this comment.
